package org.springframework.security.oauth2.client.endpoint;

import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;

/**
 * Extracts the request parameters for a refresh-token grant request.
 * <p>
 * Besides the grant type, the refresh token, the optional scope and (for POST-style client
 * authentication) the client credentials, this converter also sends the
 * {@link OAuth2ParameterNames#CLIENT_ID} parameter when the client registration uses
 * {@link ClientAuthenticationMethod#NONE}, which is the configuration used for PKCE clients.
 * <p>
 * Created 2022-03-11.
 *
 * @author luohq
 * @see AbstractOAuth2AuthorizationGrantRequestEntityConverter
 * @see OAuth2RefreshTokenGrantRequest
 */
public class OAuth2PkceRefreshTokenGrantRequestEntityConverter
		extends AbstractOAuth2AuthorizationGrantRequestEntityConverter<OAuth2RefreshTokenGrantRequest> {

	/**
	 * Assembles the parameters of the refresh-token request.
	 * <p>
	 * Parameters are added in this order: {@link OAuth2ParameterNames#GRANT_TYPE},
	 * {@link OAuth2ParameterNames#REFRESH_TOKEN}, then {@link OAuth2ParameterNames#SCOPE}
	 * (only when the request carries scopes), followed by the client identification that
	 * matches the registration's client authentication method.
	 *
	 * @param refreshTokenGrantRequest the refresh-token grant request to convert
	 * @return the parameters to send with the request
	 */
	@Override
	protected MultiValueMap<String, String> createParameters(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) {
		MultiValueMap<String, String> form = new LinkedMultiValueMap<>();
		form.add(OAuth2ParameterNames.GRANT_TYPE, refreshTokenGrantRequest.getGrantType().getValue());
		form.add(OAuth2ParameterNames.REFRESH_TOKEN, refreshTokenGrantRequest.getRefreshToken().getTokenValue());

		// Scopes are optional; when present they are sent as one space-delimited value.
		boolean hasScopes = !CollectionUtils.isEmpty(refreshTokenGrantRequest.getScopes());
		if (hasScopes) {
			String scope = StringUtils.collectionToDelimitedString(refreshTokenGrantRequest.getScopes(), " ");
			form.add(OAuth2ParameterNames.SCOPE, scope);
		}

		ClientRegistration registration = refreshTokenGrantRequest.getClientRegistration();
		ClientAuthenticationMethod authMethod = registration.getClientAuthenticationMethod();

		// POST-style client authentication: both client id and client secret go into the parameters.
		if (sendsCredentialsAsParameters(authMethod)) {
			form.add(OAuth2ParameterNames.CLIENT_ID, registration.getClientId());
			form.add(OAuth2ParameterNames.CLIENT_SECRET, registration.getClientSecret());
		}

		// Extension: a PKCE client (authentication method NONE) has no secret to send,
		// so only the client_id parameter is added when refreshing the token.
		if (ClientAuthenticationMethod.NONE.equals(authMethod)) {
			form.add(OAuth2ParameterNames.CLIENT_ID, registration.getClientId());
		}

		return form;
	}

	/**
	 * Tells whether the given client authentication method is one of the POST-style methods
	 * ({@code CLIENT_SECRET_POST} or {@code POST}).
	 *
	 * @param authMethod the client authentication method of the registration; a {@code null}
	 * value yields {@code false}
	 * @return {@code true} if the method equals {@code CLIENT_SECRET_POST} or {@code POST}
	 */
	private static boolean sendsCredentialsAsParameters(ClientAuthenticationMethod authMethod) {
		return ClientAuthenticationMethod.CLIENT_SECRET_POST.equals(authMethod)
				|| ClientAuthenticationMethod.POST.equals(authMethod);
	}

}